{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## **Fully Connected Neural Network with FFT**\n",
    "- input shape: 320\n",
    "- library: tensorflow\n",
    "- Training dataset: Basic dataset or CDD-based dataset\n",
    "- Test dataset: Basic, CDD-based and external\n",
    "- Metric: Accuracy, F1-score"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import pandas as pd\n",
    "import numpy as np\n",
    "from sklearn.preprocessing import normalize\n",
    "from sklearn import metrics\n",
    "import tensorflow as tf\n",
    "from tensorflow.keras.optimizers import Adam\n",
    "from tensorflow.keras.models import load_model\n",
    "import keras\n",
    "from keras.layers import Dense, ReLU, BatchNormalization,  Dropout\n",
    "from keras.models import Sequential, load_model\n",
    "from matplotlib import pyplot as plt\n",
    "import seaborn as sns\n",
    "from sklearn.utils import shuffle\n",
    "import time\n",
    "import datetime\n",
    "from numpy.fft import fft\n",
    "#import tensorflow_addons as tfa"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## **Definition of reusable functions**"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "def eliminate_transient(df):\n",
    "    no_transient = df[(df[0] == 1) | (df[0] == -1)]\n",
    "    no_transient.loc[no_transient[0] == -1, 0] = 0\n",
    "    num_pos = no_transient[0].value_counts()[1]\n",
    "    num_neg = no_transient[0].value_counts()[0]\n",
    "    return no_transient.to_numpy(), num_pos, num_neg"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "########## Dealing with Nan #############################\n",
    "# know bug, cannot deal with nan in consecutive np.nan value\n",
    "def replaceNan(original_array):\n",
    "    an_np_array = np.copy(original_array)\n",
    "    if not np.isnan(an_np_array).any():\n",
    "        return an_np_array\n",
    "    else:\n",
    "        for index in np.argwhere(np.isnan(an_np_array)):\n",
    "            if index[1] == 0:\n",
    "                an_np_array[index[0],index[1]] = an_np_array[index[0],index[1]+1]\n",
    "            else:\n",
    "                an_np_array[index[0],index[1]] = an_np_array[index[0],index[1]-1]\n",
    "    return an_np_array"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "#### Plot the data\n",
    "def linearlize_and_plot(df):\n",
    "    data_one_raw = []\n",
    "    for row in df:\n",
    "        data_one_raw = np.append(data_one_raw,row)\n",
    "    return data_one_raw"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "#### expand label with certain replication, in arc detection case, the ratio is 160\n",
    "def label_expansion(df,expansion_ratio):\n",
    "    label_expansion = []\n",
    "    for row in df:\n",
    "        label_expansion = np.append(label_expansion,np.full(expansion_ratio,row))\n",
    "    return label_expansion"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "def plot_data_and_label(data, label, save_figure=False, filename=\"\", title=\"\", expansion_ratio=160):\n",
    "    \"\"\"Plot the flattened signal in blue with the scaled labels overlaid in red.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    data : array-like\n",
    "        Rows of samples; they are concatenated into one trace.\n",
    "    label : numpy.ndarray\n",
    "        One label per row of ``data``. It is divided by 100 (presumably so it\n",
    "        fits the current scale) and repeated ``expansion_ratio`` times per row.\n",
    "    save_figure : bool\n",
    "        When True the figure is written to ``filename``.\n",
    "    filename : str\n",
    "        Output path used when ``save_figure`` is True.\n",
    "    title : str\n",
    "        Axes title.\n",
    "    expansion_ratio : int\n",
    "        Copies per label; should equal the number of samples per row of\n",
    "        ``data`` so both traces line up (default 160).\n",
    "    \"\"\"\n",
    "    fig, ax1 = plt.subplots()\n",
    "    ax1.set_xlabel('datapoint')\n",
    "    ax1.set_ylabel('current', color='blue')\n",
    "    ax1.plot(linearlize_and_plot(data), color='blue')\n",
    "    ax1.plot(label_expansion(label/100, expansion_ratio), color='red')\n",
    "    ax1.tick_params(axis='y', labelcolor='blue')\n",
    "    ax1.set_title(title)\n",
    "    if save_figure:\n",
    "        fig.savefig(filename)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Sanity check for replaceNan: it returns a filled copy and leaves the input untouched.\n",
    "array = np.array([[0,1,np.nan,2],[np.nan,np.nan,4,5]])\n",
    "array_filled = replaceNan(array)\n",
    "print(array)\n",
    "print(array_filled)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Clean the raw array, add FFT features, then split it into train, valid and test sets\n",
    "def shuffle_normalise_split_fft_dataset(original_array, train_ratio, valid_ratio, test_ratio, shuffle_data=True, random_state=None):\n",
    "    \"\"\"Fill NaNs, build signal + FFT features, optionally shuffle, and split.\n",
    "\n",
    "    Processing steps:\n",
    "    1. NaNs are filled with ``replaceNan``.\n",
    "    2. Sample values below 0.005 are raised to 0.005 (``zero_no_current_status``).\n",
    "    3. Each row is L1-normalised and concatenated with its L1-normalised\n",
    "       FFT magnitude spectrum.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    original_array : numpy.ndarray\n",
    "        2-D array; column 0 is the label, the other columns are samples.\n",
    "    train_ratio, valid_ratio : float\n",
    "        Fractions of rows put in the train and validation splits.\n",
    "    test_ratio : float\n",
    "        Not used for slicing: the test split is every row left after the\n",
    "        train and validation splits. Kept for call compatibility.\n",
    "    shuffle_data : bool\n",
    "        Shuffle the rows before splitting (default True).\n",
    "    random_state : int or None\n",
    "        Seed passed to ``sklearn.utils.shuffle`` for reproducible splits.\n",
    "\n",
    "    Returns\n",
    "    -------\n",
    "    tuple\n",
    "        (features in original row order, labels of shape (n, 1),\n",
    "        train_x, train_y, valid_x, valid_y, test_x, test_y)\n",
    "    \"\"\"\n",
    "    np_array = replaceNan(original_array)\n",
    "    np_data = np.copy(np_array[:, 1:])\n",
    "    np_label = np.copy(np_array[:, 0:1])\n",
    "    zero_no_current_status(np_data, 0.005)\n",
    "\n",
    "    # Magnitude spectrum of every row (FFT along the last axis).\n",
    "    np_data_fft = np.abs(fft(np_data))\n",
    "    # L1-normalise the magnitude spectrum directly (no second FFT).\n",
    "    np_data_fft_normalised = normalize(np_data_fft, axis=1, norm='l1')\n",
    "    np_data_normalized = normalize(np_data, axis=1, norm='l1')\n",
    "\n",
    "    np_data_and_fft = np.concatenate((np_label, np_data_normalized, np_data_fft_normalised), axis=1)\n",
    "    # Test the parameter: the bare name ``shuffle`` is the imported sklearn\n",
    "    # function, which is always truthy.\n",
    "    if shuffle_data:\n",
    "        np_shuffle = shuffle(np_data_and_fft, random_state=random_state)\n",
    "    else:\n",
    "        np_shuffle = np_data_and_fft\n",
    "    # split the data\n",
    "    train_index = int(len(np_shuffle) * train_ratio)    # number of training rows\n",
    "    valid_index = int(len(np_shuffle) * valid_ratio)    # number of validation rows\n",
    "    trainset = np_shuffle[0:train_index]\n",
    "    validset = np_shuffle[train_index : train_index + valid_index]\n",
    "    testset = np_shuffle[train_index + valid_index : ]\n",
    "    # output original data, and preprocessed and split data\n",
    "    return np_data_and_fft[:, 1:], np_label, trainset[:, 1:], trainset[:, 0], validset[:, 1:], validset[:, 0], testset[:, 1:], testset[:, 0]\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[[1 4 5 6]\n",
      " [0 7 8 9]]\n"
     ]
    }
   ],
   "source": [
    "# Create two sample NumPy arrays\n",
    "array1 = np.array([[1],[0]])\n",
    "array2 = np.array([[4, 5, 6],[7,8,9]])\n",
    "\n",
    "# Concatenate the arrays along the specified axis (0 for rows, 1 for columns)\n",
    "concatenated_array = np.concatenate((array1, array2),axis=1)\n",
    "\n",
    "print(concatenated_array)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [],
   "source": [
    "def zero_no_current_status(np_array, threshold):\n",
    "    np_array[np_array<threshold] = threshold"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[[0.332 0.332]\n",
      " [0.001 0.001]]\n"
     ]
    }
   ],
   "source": [
    "# Sanity check: entries below the 0.001 threshold are raised to it in place.\n",
    "array = np.array([[0.332, 0.332],\n",
    "                  [0.0001, 0.0001]])\n",
    "zero_no_current_status(array, 0.001)\n",
    "print(array)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [],
   "source": [
    "def result_log (model=\"model type\", test=\"test dataset\", filename = \"evaluation_log.txt\", cf_matrix = None, accuracy_score = None, f1_score = None):\n",
    "    ts = time.time()\n",
    "    sttime = datetime.datetime.fromtimestamp(ts).strftime('%Y%m%d_%H:%M:%S - ')\n",
    "    with open(filename,'a') as log_file:\n",
    "        log_file.write(sttime +'\\n')\n",
    "        log_file.write(\"Result of \" + model + \" on \" + test +'\\n')\n",
    "        log_file.write(f\"Accuracy: {accuracy_score:.4f}\" + '\\n')\n",
    "        log_file.write(f\"F1 score: {f1_score:.4f}\" + '\\n')\n",
    "        for line in cf_matrix:\n",
    "            log_file.write(\" \".join([str(n) for n in line]) + \"\\n\")\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [],
   "source": [
    "def evaluation(model, test_set, test_label, test_set_name, filename, model_name, threshold=0.5, batch_size=64):\n",
    "    \"\"\"Score a binary classifier on a test set and append the result to a log.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    model\n",
    "        Object with a Keras-style ``predict(x, batch_size=...)`` returning the\n",
    "        probability of class 1 for each sample.\n",
    "    test_set, test_label : array-like\n",
    "        Test features and 0/1 labels.\n",
    "    test_set_name, model_name : str\n",
    "        Names written to the log header.\n",
    "    filename : str\n",
    "        Log file passed to ``result_log``.\n",
    "    threshold : float\n",
    "        Probabilities above this value are predicted as class 1 (default 0.5).\n",
    "    batch_size : int\n",
    "        Prediction batch size (default 64).\n",
    "\n",
    "    Returns\n",
    "    -------\n",
    "    tuple\n",
    "        (confusion matrix, accuracy, F1 score)\n",
    "    \"\"\"\n",
    "    predict_probability = model.predict(test_set, batch_size=batch_size)\n",
    "    # A single sigmoid output presumably comes back as shape (n, 1); flatten it\n",
    "    # (and the labels) so sklearn receives plain 1-D vectors.\n",
    "    predict_label = np.where(np.ravel(predict_probability) > threshold, 1, 0)\n",
    "    test_label = np.ravel(test_label)\n",
    "    cf_matrix = metrics.confusion_matrix(test_label, predict_label)\n",
    "    accuracy_score = metrics.accuracy_score(test_label, predict_label)\n",
    "    f1score = metrics.f1_score(test_label, predict_label)\n",
    "    # save results\n",
    "    result_log(model=model_name, test=test_set_name, cf_matrix=cf_matrix,\n",
    "               accuracy_score=accuracy_score, f1_score=f1score, filename=filename)\n",
    "    return cf_matrix, accuracy_score, f1score"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [],
   "source": [
    "def plot_the_train_history(train_val_history, filename, metric='binary_accuracy'):\n",
    "    \"\"\"Plot the training and validation curves of one metric and save the figure.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    train_val_history\n",
    "        Object returned by Keras ``Model.fit``; its ``history`` dict must\n",
    "        contain the keys ``metric`` and ``'val_' + metric``.\n",
    "    filename : str\n",
    "        Path passed to ``Figure.savefig``.\n",
    "    metric : str\n",
    "        History key to plot (default ``'binary_accuracy'``).\n",
    "    \"\"\"\n",
    "    history = train_val_history.history\n",
    "    # Human-readable name; the default metric keeps the original accuracy labels.\n",
    "    metric_label = 'accuracy' if metric == 'binary_accuracy' else metric\n",
    "    fig, ax = plt.subplots()\n",
    "    ax.plot(history[metric])\n",
    "    ax.plot(history['val_' + metric])\n",
    "    ax.set_title('model ' + metric_label)\n",
    "    ax.set_ylabel(metric_label)\n",
    "    ax.set_xlabel('epoch')\n",
    "    ax.legend(['train', 'val'], loc='upper left')\n",
    "    fig.savefig(filename)\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Starting Point of the program"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Read data and preprocess\n",
    "1. delete the transient part\n",
    "2. shuffle the dataset and split it into train, validation and test sets"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Raw datasets have no header row; column 0 holds the label.\n",
    "DATA_DIR = \"../data\"\n",
    "M_df = pd.read_csv(f\"{DATA_DIR}/Basic_dataset.csv\", header=None)       # Basic dataset\n",
    "MV_df = pd.read_csv(f\"{DATA_DIR}/CDD_based_dataset.csv\", header=None)  # CDD-based dataset\n",
    "MF_df = pd.read_csv(f\"{DATA_DIR}/External_dataset.csv\", header=None)   # External dataset"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Model training"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Basic Model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [],
   "source": [
    "#----------------------------------------------------model------------------------------------------------------\n",
    "def fnn_training_with_fft(train_set, train_label, valid_set, valid_label, epoch, learning_rate=1e-6, batch_size=64):\n",
    "    \"\"\"Build, compile and train the fully connected binary classifier.\n",
    "\n",
    "    Architecture: three Dense -> BatchNormalization -> ReLU -> Dropout(0.25)\n",
    "    stages with 640, 1280 and 480 units (L2 penalty 1e-5 on the kernels),\n",
    "    followed by one sigmoid unit. Class imbalance is handled with per-class\n",
    "    loss weights and an output bias initialised to log(num_pos / num_neg).\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    train_set, valid_set : numpy.ndarray\n",
    "        2-D feature matrices; the input width is read from ``train_set``.\n",
    "    train_label, valid_label : numpy.ndarray\n",
    "        0/1 labels.\n",
    "    epoch : int\n",
    "        Number of training epochs.\n",
    "    learning_rate : float\n",
    "        Adam learning rate (default 1e-6).\n",
    "    batch_size : int\n",
    "        Mini-batch size (default 64).\n",
    "\n",
    "    Returns\n",
    "    -------\n",
    "    tuple\n",
    "        (trained model, History object returned by ``model.fit``)\n",
    "\n",
    "    Raises\n",
    "    ------\n",
    "    ZeroDivisionError\n",
    "        If ``train_label`` contains no 0s or no 1s.\n",
    "    \"\"\"\n",
    "    epochs = epoch\n",
    "    dropout = 0.25\n",
    "    l2 = tf.keras.regularizers.L2(l2=1e-5)\n",
    "    num_pos = len(train_label[train_label==1])\n",
    "    num_neg = len(train_label[train_label==0])\n",
    "    total_samples = num_pos + num_neg\n",
    "    output_bias = np.log([num_pos / num_neg])                 # initial bias of the output unit (imbalance data)\n",
    "    weight_class_0 = (1 / num_neg) * (total_samples / 2.0)    # weights for the loss function (imbalance data)\n",
    "    weight_class_1 = (1 / num_pos) * (total_samples / 2.0)\n",
    "    class_weights = {0: weight_class_0,\n",
    "                     1: weight_class_1}\n",
    "\n",
    "    def hidden_block(units, **dense_kwargs):\n",
    "        # A new GlorotNormal instance per layer: Keras warns that reusing one\n",
    "        # unseeded initializer instance returns identical values on every call.\n",
    "        return [\n",
    "            Dense(units, kernel_regularizer=l2, use_bias=True,\n",
    "                  kernel_initializer=tf.keras.initializers.GlorotNormal(), **dense_kwargs),\n",
    "            BatchNormalization(),\n",
    "            ReLU(),\n",
    "            Dropout(dropout),\n",
    "        ]\n",
    "\n",
    "    ################################# architecture\n",
    "    model = Sequential(\n",
    "        hidden_block(640, input_shape=(train_set.shape[1],))\n",
    "        + hidden_block(1280)\n",
    "        + hidden_block(480)\n",
    "        # sigmoid(log(pos/neg)) = pos/(pos+neg): start at the training positive rate.\n",
    "        + [Dense(1, activation='sigmoid',\n",
    "                 bias_initializer=tf.keras.initializers.Constant(output_bias))]\n",
    "    )\n",
    "\n",
    "    #############################  compiling\n",
    "    metric_accuracy = keras.metrics.BinaryAccuracy()\n",
    "    model.compile(\n",
    "        optimizer=Adam(learning_rate=learning_rate),\n",
    "        loss=keras.losses.BinaryCrossentropy(),\n",
    "        metrics=[metric_accuracy]\n",
    "    )\n",
    "\n",
    "    ############################## training\n",
    "    train_val_history = model.fit(\n",
    "        train_set,\n",
    "        train_label,\n",
    "        batch_size=batch_size,\n",
    "        epochs=epochs,\n",
    "        validation_data=(valid_set, valid_label),\n",
    "        shuffle=True,\n",
    "        class_weight=class_weights\n",
    "    )\n",
    "    return model, train_val_history\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "c:\\Users\\marti\\.conda\\envs\\causality_4\\lib\\site-packages\\keras\\src\\initializers\\initializers.py:120: UserWarning: The initializer GlorotNormal is unseeded and being called multiple times, which will return identical values each time (even if the initializer is unseeded). Please update your code to provide a seed to the initializer, or avoid using the same initializer instance more than once.\n",
      "  warnings.warn(\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 1/20\n",
      " 72/279 [======>.......................] - ETA: 4s - loss: 0.6708 - binary_accuracy: 0.6200"
     ]
    },
    {
     "ename": "KeyboardInterrupt",
     "evalue": "",
     "output_type": "error",
     "traceback": [
      "\u001b[1;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[1;31mKeyboardInterrupt\u001b[0m                         Traceback (most recent call last)",
      "Cell \u001b[1;32mIn[17], line 23\u001b[0m\n\u001b[0;32m     21\u001b[0m \u001b[38;5;66;03m# train the model\u001b[39;00m\n\u001b[0;32m     22\u001b[0m training_epoch \u001b[38;5;241m=\u001b[39m \u001b[38;5;241m20\u001b[39m\n\u001b[1;32m---> 23\u001b[0m fnn_basic_with_fft_model, train_val_history \u001b[38;5;241m=\u001b[39m \u001b[43mfnn_training_with_fft\u001b[49m\u001b[43m(\u001b[49m\u001b[43mtrain_set\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mtrain_label\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mvalid_set\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mvalid_label\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mepoch\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43m \u001b[49m\u001b[43mtraining_epoch\u001b[49m\u001b[43m)\u001b[49m\n\u001b[0;32m     24\u001b[0m \u001b[38;5;66;03m# evaluation\u001b[39;00m\n\u001b[0;32m     25\u001b[0m log_filename \u001b[38;5;241m=\u001b[39m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mfnn_with_fft.txt\u001b[39m\u001b[38;5;124m\"\u001b[39m\n",
      "Cell \u001b[1;32mIn[16], line 49\u001b[0m, in \u001b[0;36mfnn_training_with_fft\u001b[1;34m(train_set, train_label, valid_set, valid_label, epoch)\u001b[0m\n\u001b[0;32m     42\u001b[0m model\u001b[38;5;241m.\u001b[39mcompile(\n\u001b[0;32m     43\u001b[0m     optimizer\u001b[38;5;241m=\u001b[39m Adam(learning_rate\u001b[38;5;241m=\u001b[39m \u001b[38;5;241m1e-6\u001b[39m),\n\u001b[0;32m     44\u001b[0m     loss\u001b[38;5;241m=\u001b[39m keras\u001b[38;5;241m.\u001b[39mlosses\u001b[38;5;241m.\u001b[39mBinaryCrossentropy(),\n\u001b[0;32m     45\u001b[0m     metrics\u001b[38;5;241m=\u001b[39m [metric_accuracy]\n\u001b[0;32m     46\u001b[0m )\n\u001b[0;32m     48\u001b[0m \u001b[38;5;66;03m############################## training\u001b[39;00m\n\u001b[1;32m---> 49\u001b[0m train_val_history \u001b[38;5;241m=\u001b[39m \u001b[43mmodel\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mfit\u001b[49m\u001b[43m(\u001b[49m\n\u001b[0;32m     50\u001b[0m \u001b[43m    \u001b[49m\u001b[43mtrain_set\u001b[49m\u001b[43m,\u001b[49m\n\u001b[0;32m     51\u001b[0m \u001b[43m    \u001b[49m\u001b[43mtrain_label\u001b[49m\u001b[43m,\u001b[49m\n\u001b[0;32m     52\u001b[0m \u001b[43m    \u001b[49m\u001b[43mbatch_size\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mbatch_size\u001b[49m\u001b[43m,\u001b[49m\n\u001b[0;32m     53\u001b[0m \u001b[43m    \u001b[49m\u001b[43mepochs\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mepochs\u001b[49m\u001b[43m,\u001b[49m\n\u001b[0;32m     54\u001b[0m \u001b[43m    \u001b[49m\u001b[43mvalidation_data\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43m(\u001b[49m\u001b[43mvalid_set\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mvalid_label\u001b[49m\u001b[43m)\u001b[49m\u001b[43m,\u001b[49m\n\u001b[0;32m     55\u001b[0m \u001b[43m    \u001b[49m\u001b[43mshuffle\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43;01mTrue\u001b[39;49;00m\u001b[43m,\u001b[49m\n\u001b[0;32m     56\u001b[0m \u001b[43m    
\u001b[49m\u001b[43mclass_weight\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mclass_weights\u001b[49m\n\u001b[0;32m     57\u001b[0m \u001b[43m\u001b[49m\u001b[43m)\u001b[49m\n\u001b[0;32m     58\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m model, train_val_history\n",
      "File \u001b[1;32mc:\\Users\\marti\\.conda\\envs\\causality_4\\lib\\site-packages\\keras\\src\\utils\\traceback_utils.py:65\u001b[0m, in \u001b[0;36mfilter_traceback.<locals>.error_handler\u001b[1;34m(*args, **kwargs)\u001b[0m\n\u001b[0;32m     63\u001b[0m filtered_tb \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mNone\u001b[39;00m\n\u001b[0;32m     64\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[1;32m---> 65\u001b[0m     \u001b[38;5;28;01mreturn\u001b[39;00m fn(\u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs)\n\u001b[0;32m     66\u001b[0m \u001b[38;5;28;01mexcept\u001b[39;00m \u001b[38;5;167;01mException\u001b[39;00m \u001b[38;5;28;01mas\u001b[39;00m e:\n\u001b[0;32m     67\u001b[0m     filtered_tb \u001b[38;5;241m=\u001b[39m _process_traceback_frames(e\u001b[38;5;241m.\u001b[39m__traceback__)\n",
      "File \u001b[1;32mc:\\Users\\marti\\.conda\\envs\\causality_4\\lib\\site-packages\\keras\\src\\engine\\training.py:1742\u001b[0m, in \u001b[0;36mModel.fit\u001b[1;34m(self, x, y, batch_size, epochs, verbose, callbacks, validation_split, validation_data, shuffle, class_weight, sample_weight, initial_epoch, steps_per_epoch, validation_steps, validation_batch_size, validation_freq, max_queue_size, workers, use_multiprocessing)\u001b[0m\n\u001b[0;32m   1734\u001b[0m \u001b[38;5;28;01mwith\u001b[39;00m tf\u001b[38;5;241m.\u001b[39mprofiler\u001b[38;5;241m.\u001b[39mexperimental\u001b[38;5;241m.\u001b[39mTrace(\n\u001b[0;32m   1735\u001b[0m     \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mtrain\u001b[39m\u001b[38;5;124m\"\u001b[39m,\n\u001b[0;32m   1736\u001b[0m     epoch_num\u001b[38;5;241m=\u001b[39mepoch,\n\u001b[1;32m   (...)\u001b[0m\n\u001b[0;32m   1739\u001b[0m     _r\u001b[38;5;241m=\u001b[39m\u001b[38;5;241m1\u001b[39m,\n\u001b[0;32m   1740\u001b[0m ):\n\u001b[0;32m   1741\u001b[0m     callbacks\u001b[38;5;241m.\u001b[39mon_train_batch_begin(step)\n\u001b[1;32m-> 1742\u001b[0m     tmp_logs \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mtrain_function\u001b[49m\u001b[43m(\u001b[49m\u001b[43miterator\u001b[49m\u001b[43m)\u001b[49m\n\u001b[0;32m   1743\u001b[0m     \u001b[38;5;28;01mif\u001b[39;00m data_handler\u001b[38;5;241m.\u001b[39mshould_sync:\n\u001b[0;32m   1744\u001b[0m         context\u001b[38;5;241m.\u001b[39masync_wait()\n",
      "File \u001b[1;32mc:\\Users\\marti\\.conda\\envs\\causality_4\\lib\\site-packages\\tensorflow\\python\\util\\traceback_utils.py:150\u001b[0m, in \u001b[0;36mfilter_traceback.<locals>.error_handler\u001b[1;34m(*args, **kwargs)\u001b[0m\n\u001b[0;32m    148\u001b[0m filtered_tb \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mNone\u001b[39;00m\n\u001b[0;32m    149\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[1;32m--> 150\u001b[0m   \u001b[38;5;28;01mreturn\u001b[39;00m fn(\u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs)\n\u001b[0;32m    151\u001b[0m \u001b[38;5;28;01mexcept\u001b[39;00m \u001b[38;5;167;01mException\u001b[39;00m \u001b[38;5;28;01mas\u001b[39;00m e:\n\u001b[0;32m    152\u001b[0m   filtered_tb \u001b[38;5;241m=\u001b[39m _process_traceback_frames(e\u001b[38;5;241m.\u001b[39m__traceback__)\n",
      "File \u001b[1;32mc:\\Users\\marti\\.conda\\envs\\causality_4\\lib\\site-packages\\tensorflow\\python\\eager\\polymorphic_function\\polymorphic_function.py:825\u001b[0m, in \u001b[0;36mFunction.__call__\u001b[1;34m(self, *args, **kwds)\u001b[0m\n\u001b[0;32m    822\u001b[0m compiler \u001b[38;5;241m=\u001b[39m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mxla\u001b[39m\u001b[38;5;124m\"\u001b[39m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_jit_compile \u001b[38;5;28;01melse\u001b[39;00m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mnonXla\u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[0;32m    824\u001b[0m \u001b[38;5;28;01mwith\u001b[39;00m OptionalXlaContext(\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_jit_compile):\n\u001b[1;32m--> 825\u001b[0m   result \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_call(\u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwds)\n\u001b[0;32m    827\u001b[0m new_tracing_count \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mexperimental_get_tracing_count()\n\u001b[0;32m    828\u001b[0m without_tracing \u001b[38;5;241m=\u001b[39m (tracing_count \u001b[38;5;241m==\u001b[39m new_tracing_count)\n",
      "File \u001b[1;32mc:\\Users\\marti\\.conda\\envs\\causality_4\\lib\\site-packages\\tensorflow\\python\\eager\\polymorphic_function\\polymorphic_function.py:857\u001b[0m, in \u001b[0;36mFunction._call\u001b[1;34m(self, *args, **kwds)\u001b[0m\n\u001b[0;32m    854\u001b[0m   \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_lock\u001b[38;5;241m.\u001b[39mrelease()\n\u001b[0;32m    855\u001b[0m   \u001b[38;5;66;03m# In this case we have created variables on the first call, so we run the\u001b[39;00m\n\u001b[0;32m    856\u001b[0m   \u001b[38;5;66;03m# defunned version which is guaranteed to never create variables.\u001b[39;00m\n\u001b[1;32m--> 857\u001b[0m   \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_no_variable_creation_fn(\u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwds)  \u001b[38;5;66;03m# pylint: disable=not-callable\u001b[39;00m\n\u001b[0;32m    858\u001b[0m \u001b[38;5;28;01melif\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_variable_creation_fn \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m:\n\u001b[0;32m    859\u001b[0m   \u001b[38;5;66;03m# Release the lock early so that multiple threads can perform the call\u001b[39;00m\n\u001b[0;32m    860\u001b[0m   \u001b[38;5;66;03m# in parallel.\u001b[39;00m\n\u001b[0;32m    861\u001b[0m   \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_lock\u001b[38;5;241m.\u001b[39mrelease()\n",
      "File \u001b[1;32mc:\\Users\\marti\\.conda\\envs\\causality_4\\lib\\site-packages\\tensorflow\\python\\eager\\polymorphic_function\\tracing_compiler.py:148\u001b[0m, in \u001b[0;36mTracingCompiler.__call__\u001b[1;34m(self, *args, **kwargs)\u001b[0m\n\u001b[0;32m    145\u001b[0m \u001b[38;5;28;01mwith\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_lock:\n\u001b[0;32m    146\u001b[0m   (concrete_function,\n\u001b[0;32m    147\u001b[0m    filtered_flat_args) \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_maybe_define_function(args, kwargs)\n\u001b[1;32m--> 148\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mconcrete_function\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_call_flat\u001b[49m\u001b[43m(\u001b[49m\n\u001b[0;32m    149\u001b[0m \u001b[43m    \u001b[49m\u001b[43mfiltered_flat_args\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mcaptured_inputs\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mconcrete_function\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mcaptured_inputs\u001b[49m\u001b[43m)\u001b[49m\n",
      "File \u001b[1;32mc:\\Users\\marti\\.conda\\envs\\causality_4\\lib\\site-packages\\tensorflow\\python\\eager\\polymorphic_function\\monomorphic_function.py:1349\u001b[0m, in \u001b[0;36mConcreteFunction._call_flat\u001b[1;34m(self, args, captured_inputs)\u001b[0m\n\u001b[0;32m   1345\u001b[0m possible_gradient_type \u001b[38;5;241m=\u001b[39m gradients_util\u001b[38;5;241m.\u001b[39mPossibleTapeGradientTypes(args)\n\u001b[0;32m   1346\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m (possible_gradient_type \u001b[38;5;241m==\u001b[39m gradients_util\u001b[38;5;241m.\u001b[39mPOSSIBLE_GRADIENT_TYPES_NONE\n\u001b[0;32m   1347\u001b[0m     \u001b[38;5;129;01mand\u001b[39;00m executing_eagerly):\n\u001b[0;32m   1348\u001b[0m   \u001b[38;5;66;03m# No tape is watching; skip to running the function.\u001b[39;00m\n\u001b[1;32m-> 1349\u001b[0m   \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_build_call_outputs(\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_inference_function\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m)\u001b[49m)\n\u001b[0;32m   1350\u001b[0m forward_backward \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_select_forward_and_backward_functions(\n\u001b[0;32m   1351\u001b[0m     args,\n\u001b[0;32m   1352\u001b[0m     possible_gradient_type,\n\u001b[0;32m   1353\u001b[0m     executing_eagerly)\n\u001b[0;32m   1354\u001b[0m forward_function, args_with_tangents \u001b[38;5;241m=\u001b[39m forward_backward\u001b[38;5;241m.\u001b[39mforward()\n",
      "File \u001b[1;32mc:\\Users\\marti\\.conda\\envs\\causality_4\\lib\\site-packages\\tensorflow\\python\\eager\\polymorphic_function\\atomic_function.py:196\u001b[0m, in \u001b[0;36mAtomicFunction.__call__\u001b[1;34m(self, *args)\u001b[0m\n\u001b[0;32m    194\u001b[0m \u001b[38;5;28;01mwith\u001b[39;00m record\u001b[38;5;241m.\u001b[39mstop_recording():\n\u001b[0;32m    195\u001b[0m   \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_bound_context\u001b[38;5;241m.\u001b[39mexecuting_eagerly():\n\u001b[1;32m--> 196\u001b[0m     outputs \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_bound_context\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mcall_function\u001b[49m\u001b[43m(\u001b[49m\n\u001b[0;32m    197\u001b[0m \u001b[43m        \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mname\u001b[49m\u001b[43m,\u001b[49m\n\u001b[0;32m    198\u001b[0m \u001b[43m        \u001b[49m\u001b[38;5;28;43mlist\u001b[39;49m\u001b[43m(\u001b[49m\u001b[43margs\u001b[49m\u001b[43m)\u001b[49m\u001b[43m,\u001b[49m\n\u001b[0;32m    199\u001b[0m \u001b[43m        \u001b[49m\u001b[38;5;28;43mlen\u001b[39;49m\u001b[43m(\u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mfunction_type\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mflat_outputs\u001b[49m\u001b[43m)\u001b[49m\u001b[43m,\u001b[49m\n\u001b[0;32m    200\u001b[0m \u001b[43m    \u001b[49m\u001b[43m)\u001b[49m\n\u001b[0;32m    201\u001b[0m   \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[0;32m    202\u001b[0m     outputs \u001b[38;5;241m=\u001b[39m make_call_op_in_graph(\u001b[38;5;28mself\u001b[39m, \u001b[38;5;28mlist\u001b[39m(args))\n",
      "File \u001b[1;32mc:\\Users\\marti\\.conda\\envs\\causality_4\\lib\\site-packages\\tensorflow\\python\\eager\\context.py:1457\u001b[0m, in \u001b[0;36mContext.call_function\u001b[1;34m(self, name, tensor_inputs, num_outputs)\u001b[0m\n\u001b[0;32m   1455\u001b[0m cancellation_context \u001b[38;5;241m=\u001b[39m cancellation\u001b[38;5;241m.\u001b[39mcontext()\n\u001b[0;32m   1456\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m cancellation_context \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m:\n\u001b[1;32m-> 1457\u001b[0m   outputs \u001b[38;5;241m=\u001b[39m \u001b[43mexecute\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mexecute\u001b[49m\u001b[43m(\u001b[49m\n\u001b[0;32m   1458\u001b[0m \u001b[43m      \u001b[49m\u001b[43mname\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mdecode\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[38;5;124;43mutf-8\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[43m)\u001b[49m\u001b[43m,\u001b[49m\n\u001b[0;32m   1459\u001b[0m \u001b[43m      \u001b[49m\u001b[43mnum_outputs\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mnum_outputs\u001b[49m\u001b[43m,\u001b[49m\n\u001b[0;32m   1460\u001b[0m \u001b[43m      \u001b[49m\u001b[43minputs\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mtensor_inputs\u001b[49m\u001b[43m,\u001b[49m\n\u001b[0;32m   1461\u001b[0m \u001b[43m      \u001b[49m\u001b[43mattrs\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mattrs\u001b[49m\u001b[43m,\u001b[49m\n\u001b[0;32m   1462\u001b[0m \u001b[43m      \u001b[49m\u001b[43mctx\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[43m,\u001b[49m\n\u001b[0;32m   1463\u001b[0m \u001b[43m  \u001b[49m\u001b[43m)\u001b[49m\n\u001b[0;32m   1464\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[0;32m   1465\u001b[0m   outputs \u001b[38;5;241m=\u001b[39m execute\u001b[38;5;241m.\u001b[39mexecute_with_cancellation(\n\u001b[0;32m   1466\u001b[0m       
name\u001b[38;5;241m.\u001b[39mdecode(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mutf-8\u001b[39m\u001b[38;5;124m\"\u001b[39m),\n\u001b[0;32m   1467\u001b[0m       num_outputs\u001b[38;5;241m=\u001b[39mnum_outputs,\n\u001b[1;32m   (...)\u001b[0m\n\u001b[0;32m   1471\u001b[0m       cancellation_manager\u001b[38;5;241m=\u001b[39mcancellation_context,\n\u001b[0;32m   1472\u001b[0m   )\n",
      "File \u001b[1;32mc:\\Users\\marti\\.conda\\envs\\causality_4\\lib\\site-packages\\tensorflow\\python\\eager\\execute.py:53\u001b[0m, in \u001b[0;36mquick_execute\u001b[1;34m(op_name, num_outputs, inputs, attrs, ctx, name)\u001b[0m\n\u001b[0;32m     51\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[0;32m     52\u001b[0m   ctx\u001b[38;5;241m.\u001b[39mensure_initialized()\n\u001b[1;32m---> 53\u001b[0m   tensors \u001b[38;5;241m=\u001b[39m \u001b[43mpywrap_tfe\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mTFE_Py_Execute\u001b[49m\u001b[43m(\u001b[49m\u001b[43mctx\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_handle\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mdevice_name\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mop_name\u001b[49m\u001b[43m,\u001b[49m\n\u001b[0;32m     54\u001b[0m \u001b[43m                                      \u001b[49m\u001b[43minputs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mattrs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mnum_outputs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[0;32m     55\u001b[0m \u001b[38;5;28;01mexcept\u001b[39;00m core\u001b[38;5;241m.\u001b[39m_NotOkStatusException \u001b[38;5;28;01mas\u001b[39;00m e:\n\u001b[0;32m     56\u001b[0m   \u001b[38;5;28;01mif\u001b[39;00m name \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m:\n",
      "\u001b[1;31mKeyboardInterrupt\u001b[0m: "
     ]
    }
   ],
   "source": [
    "for i in range(10):\n",
    "    \n",
    "    # preprocessing of the data, delete transient part\n",
    "    M_df_no_transient, M_pos, M_neg = eliminate_transient(M_df)\n",
    "    MV_df_no_transient, MV_pos, MV_neg = eliminate_transient(MV_df)\n",
    "    MF_df_no_transient, MF_pos, MF_neg = eliminate_transient(MF_df)\n",
    "\n",
    "    # normalisation of datasets\n",
    "    M_data, M_label, M_train_data, M_train_label, M_valid_data, M_valid_label, M_test_data, M_test_label = shuffle_normalise_split_fft_dataset(M_df_no_transient, 0.6, 0.2, 0.2)\n",
    "    MV_data, MV_label, MV_train_data, MV_train_label, MV_valid_data, MV_valid_label, MV_test_data, MV_test_label = shuffle_normalise_split_fft_dataset(MV_df_no_transient, 0.6, 0.2, 0.2)\n",
    "    MF_data, MF_label, MF_train_data, MF_train_label, MF_valid_data, MF_valid_label, MF_test_data, MF_test_label = shuffle_normalise_split_fft_dataset(MF_df_no_transient, 0.6, 0.2, 0.2)\n",
    "    \n",
    "    # set the dataste for training, validation and testing\n",
    "    # the validation is to see if the data is overfitting, in this case, we can use another dataset to check\n",
    "    # the MF data is only for evaluation, therefore, we can use all of it\n",
    "\n",
    "    train_len = min([len(M_train_label),len(MV_train_label)])\n",
    "    train_set, train_label = M_train_data[0:train_len,:], M_train_label[0:train_len]\n",
    "    valid_set, valid_label = np.concatenate((MV_valid_data,MF_valid_data),axis=0), np.concatenate((MV_valid_label,MF_valid_label))\n",
    "\n",
    "    # train the model\n",
    "    training_epoch = 20\n",
    "    fnn_basic_with_fft_model, train_val_history = fnn_training_with_fft(train_set, train_label, valid_set, valid_label, epoch = training_epoch)\n",
    "    # evaluation\n",
    "    log_filename = \"fnn_with_fft.txt\"\n",
    "    plot_the_train_history(train_val_history, filename = f\"Basic fnn with fft {i} epoch {training_epoch}\")\n",
    "    evaluation(fnn_basic_with_fft_model,M_test_data,M_test_label,\"Basic test dataset\",filename=log_filename, model_name=\"Basic fnn with fft\")\n",
    "    evaluation(fnn_basic_with_fft_model,MV_test_data,MV_test_label,\"CDD-based test dataset\",filename=log_filename, model_name=\"Basic fnn with fft\")\n",
    "    evaluation(fnn_basic_with_fft_model,MF_data,MF_label,\"External test dataset\",filename=log_filename, model_name=\"Basic fnn with fft\")\n",
    "\n",
    "    # train the model\n",
    "    training_epoch = 20\n",
    "    train_set, train_label = MV_train_data[0:train_len,:], MV_train_label[0:train_len]\n",
    "    valid_set, valid_label = np.concatenate((M_valid_data,MF_valid_data),axis=0), np.concatenate((M_valid_label,MF_valid_label))\n",
    "\n",
    "    fnn_CDD_with_fft_model, train_val_history = fnn_training_with_fft(train_set, train_label, valid_set, valid_label, epoch = training_epoch)\n",
    "    # evaluation\n",
    "    \n",
    "    plot_the_train_history(train_val_history, filename = f\"CDD-based fnn with fft {i} epoch {training_epoch}\")\n",
    "    evaluation(fnn_CDD_with_fft_model,M_test_data,M_test_label,\"Basic test dataset\",filename=log_filename, model_name=\"CDD-based fnn with fft\")\n",
    "    evaluation(fnn_CDD_with_fft_model,MV_test_data,MV_test_label,\"CDD-based test dataset\",filename=log_filename, model_name=\"CDD-based fnn with fft\")\n",
    "    evaluation(fnn_CDD_with_fft_model,MF_data,MF_label,\"External dataset\",filename=log_filename, model_name=\"CDD-based fnn with fft\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Save the models and output data to disk\n",
    "fnn_basic_with_fft_model.save(f'fnn_basic_with_fft_model_{i}.h5')\n",
    "fnn_CDD_with_fft_model.save(f'fnn_CDD_with_fft_model_{i}.h5')"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "causality_4",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.18"
  },
  "orig_nbformat": 4
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
